package karma.pool.pool.proxy;

import javassist.*;
import javassist.bytecode.ClassFile;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.lang.reflect.Array;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.*;

/**
 * Build-time generator that uses Javassist to emit concrete JDBC proxy classes.
 * <p>
 * For each proxy base class ({@link ProxyConnection}, {@link ProxyStatement},
 * {@link ProxyPreparedStatement}, {@link ProxyResultSet}) it creates a public final subclass named
 * {@code Karma<SimpleName>} in the same package. That subclass implements every method of the
 * corresponding JDBC interface hierarchy that the base class does not declare {@code final}.
 * Generated bodies call through to {@code resultSet}, or to {@code super} where the base class
 * already has a concrete implementation. {@code resultSet} and {@code checkException} are
 * members the base classes are expected to provide (NOTE(review): confirm against base classes).
 * <p>
 * It then rewrites the factory methods of {@link ProxyFactory} to instantiate the generated
 * classes. All class files are written to {@code <genDirectory>target/classes}.
 */
@Slf4j
public final class JavassistProxyFactory {
   private static ClassPool classPool;
   private static String genDirectory = "";

   /**
    * Entry point, typically invoked from the build after compilation.
    *
    * @param args optional; {@code args[0]} is the base directory under which
    *             {@code target/classes} lives. A trailing separator is added if missing.
    */
   public static void main(String... args) throws Exception {
      classPool = new ClassPool();
      // Lets the generated bodies refer to SQLException without qualification.
      classPool.importPackage("java.sql");
      classPool.appendClassPath(new LoaderClassPath(JavassistProxyFactory.class.getClassLoader()));

      if (args.length > 0) {
         genDirectory = args[0];
         // Without a trailing separator, "dir" + "target/classes" would yield "dirtarget/classes".
         if (!genDirectory.isEmpty() && !genDirectory.endsWith("/") && !genDirectory.endsWith("\\")) {
            genDirectory += "/";
         }
      }

      // Cast is not needed for these
      String methodBody = "{ try { return resultSet.method($$); } catch (SQLException e) { throw checkException(e); } }";
      generateProxyClass(Connection.class, ProxyConnection.class.getName(), methodBody);
      generateProxyClass(Statement.class, ProxyStatement.class.getName(), methodBody);
      generateProxyClass(ResultSet.class, ProxyResultSet.class.getName(), methodBody);

      // For these we have to cast the resultSet
      methodBody = "{ try { return ((cast) resultSet).method($$); } catch (SQLException e) { throw checkException(e); } }";
      generateProxyClass(PreparedStatement.class, ProxyPreparedStatement.class.getName(), methodBody);

      modifyProxyFactory();
   }

   /**
    * Replaces the bodies of the {@link ProxyFactory} {@code getProxy*} methods so that each one
    * constructs the matching generated {@code Karma*} proxy class, then writes the modified
    * class file.
    */
   private static void modifyProxyFactory() throws NotFoundException, CannotCompileException, IOException {
      String factoryName = ProxyFactory.class.getName();
      log.info("Generating method bodies for {}", factoryName);

      CtClass ctClass = classPool.getCtClass(factoryName);
      for (CtMethod ctMethod : ctClass.getMethods()) {
         String proxyClassName;
         switch (ctMethod.getName()) {
            case "getProxyConnection":
               proxyClassName = toGeneratedClassName(ProxyConnection.class.getName());
               break;
            case "getProxyStatement":
               proxyClassName = toGeneratedClassName(ProxyStatement.class.getName());
               break;
            case "getProxyPreparedStatement":
               proxyClassName = toGeneratedClassName(ProxyPreparedStatement.class.getName());
               break;
            case "getProxyResultSet":
               proxyClassName = toGeneratedClassName(ProxyResultSet.class.getName());
               break;
            default:
               // unhandled method
               continue;
         }
         // Instantiate the generated subclass, not the base proxy class it extends.
         ctMethod.setBody("{return new " + proxyClassName + "($$);}");
      }

      ctClass.writeFile(outputDirectory());
   }

   /**
    * Generates {@code Karma<SimpleName>} extending {@code superClassName} and implementing every
    * interface reachable from {@code primaryInterface}.
    *
    * @param primaryInterface the JDBC interface being proxied; used as the cast target in bodies
    * @param superClassName   fully-qualified name of the proxy base class
    * @param methodBody       body template; the tokens {@code cast}, {@code method} and
    *                         {@code resultSet} are substituted per generated method
    */
   private static <T> void generateProxyClass(Class<T> primaryInterface, String superClassName, String methodBody) throws Exception {
      String newClassName = toGeneratedClassName(superClassName);

      CtClass superCt = classPool.getCtClass(superClassName);
      CtClass targetCt = classPool.makeClass(newClassName, superCt);
      // setModifiers replaces all flags, so PUBLIC and FINAL must be set together.
      targetCt.setModifiers(Modifier.PUBLIC | Modifier.FINAL);

      log.info("Generating {}", newClassName);

      // Make a set of method signatures we inherit implementation for, so we don't generate delegates for these
      Set<String> superSigs = new HashSet<>();
      for (CtMethod method : superCt.getMethods()) {
         if (Modifier.isFinal(method.getModifiers())) {
            superSigs.add(method.getName() + method.getSignature());
         }
      }

      Set<String> methods = new HashSet<>();
      for (Class<?> intf : getAllInterfaces(primaryInterface)) {
         CtClass intfCt = classPool.getCtClass(intf.getName());
         targetCt.addInterface(intfCt);
         for (CtMethod intfMethod : intfCt.getDeclaredMethods()) {
            final String signature = intfMethod.getName() + intfMethod.getSignature();

            // don't generate delegates for methods we override
            if (superSigs.contains(signature)) {
               continue;
            }

            // Ignore already added methods that come from other interfaces
            if (!methods.add(signature)) {
               continue;
            }

            // Clone the method we want to inject into
            CtMethod method = CtNewMethod.copy(intfMethod, targetCt, null);

            String modifiedBody = methodBody;

            // If the super-Proxy has concrete methods (non-abstract), transform the call into a simple super.method() call
            CtMethod superMethod = superCt.getMethod(intfMethod.getName(), intfMethod.getSignature());
            if (!Modifier.isAbstract(superMethod.getModifiers()) && !isDefaultMethod(intf, intfMethod)) {
               modifiedBody = modifiedBody.replace("((cast) ", "");
               modifiedBody = modifiedBody.replace("resultSet", "super");
               modifiedBody = modifiedBody.replace("super)", "super");
            }

            modifiedBody = modifiedBody.replace("cast", primaryInterface.getName());

            // Generate a method that simply invokes the same method on the resultSet
            if (isThrowsSqlException(intfMethod)) {
               modifiedBody = modifiedBody.replace("method", method.getName());
            } else {
               // No checked SQLException to translate, so no try/catch wrapper is needed.
               modifiedBody = "{ return ((cast) resultSet).method($$); }".replace("method", method.getName()).replace("cast", primaryInterface.getName());
            }

            if (method.getReturnType() == CtClass.voidType) {
               modifiedBody = modifiedBody.replace("return", "");
            }

            method.setBody(modifiedBody);
            targetCt.addMethod(method);
         }
      }

      targetCt.getClassFile().setMajorVersion(ClassFile.JAVA_8);
      targetCt.writeFile(outputDirectory());
   }

   /**
    * Maps a proxy base class name to its generated subclass name, e.g.
    * {@code a.b.ProxyConnection -> a.b.KarmaProxyConnection}.
    */
   private static String toGeneratedClassName(String superClassName) {
      return superClassName.replaceAll("(.+)\\.(\\w+)", "$1.Karma$2");
   }

   /** Directory into which generated/modified class files are written. */
   private static String outputDirectory() {
      return genDirectory + "target/classes";
   }

   /**
    * Returns whether {@code ctMethod} declares an exception type whose simple name is
    * {@code SQLException}. Exception types that cannot be resolved are ignored.
    */
   private static boolean isThrowsSqlException(CtMethod ctMethod) {
      try {
         for (CtClass ctClass : ctMethod.getExceptionTypes()) {
            if (ctClass.getSimpleName().equals("SQLException")) {
               return true;
            }
         }
      } catch (NotFoundException ignored) {
         // Unresolvable throws clause: treat as "does not throw SQLException".
      }

      return false;
   }

   /**
    * Returns whether {@code ctMethod}, declared on {@code intf}, is a Java 8 default method.
    *
    * @throws NoSuchMethodException if the method is not declared on {@code intf}
    */
   private static boolean isDefaultMethod(Class<?> intf, CtMethod ctMethod) throws Exception {
      List<Class<?>> paramTypes = new ArrayList<>();

      for (CtClass ctClass : ctMethod.getParameterTypes()) {
         paramTypes.add(toJavaClass(ctClass));
      }

      return intf.getDeclaredMethod(ctMethod.getName(), paramTypes.toArray(new Class<?>[0])).isDefault();
   }

   /**
    * Collects every interface implemented or extended by {@code clazz}, including {@code clazz}
    * itself when it is an interface. Super-interfaces are added before the interfaces that extend
    * them, and insertion order is preserved.
    */
   private static Set<Class<?>> getAllInterfaces(Class<?> clazz) {
      Set<Class<?>> interfaces = new LinkedHashSet<>();
      for (Class<?> intf : clazz.getInterfaces()) {
         if (intf.getInterfaces().length > 0) {
            interfaces.addAll(getAllInterfaces(intf));
         }
         interfaces.add(intf);
      }
      if (clazz.getSuperclass() != null) {
         interfaces.addAll(getAllInterfaces(clazz.getSuperclass()));
      }

      if (clazz.isInterface()) {
         interfaces.add(clazz);
      }

      return interfaces;
   }

   /** Resolves a Javassist type (primitive, reference or array) to its runtime {@link Class}. */
   private static Class<?> toJavaClass(CtClass cls) throws Exception {
      return toJavaClass(cls.getName());
   }

   /**
    * Resolves a Java source-style type name ({@code int}, {@code java.lang.String},
    * {@code byte[][]}) to its runtime {@link Class}.
    */
   private static Class<?> toJavaClass(String cn) throws Exception {
      if (cn.endsWith("[]")) {
         // Peel exactly one dimension per recursion so multi-dimensional arrays keep their rank.
         Class<?> componentType = toJavaClass(cn.substring(0, cn.length() - 2));
         return Array.newInstance(componentType, 0).getClass();
      }

      switch (cn) {
         case "int":
            return int.class;
         case "long":
            return long.class;
         case "short":
            return short.class;
         case "byte":
            return byte.class;
         case "float":
            return float.class;
         case "double":
            return double.class;
         case "boolean":
            return boolean.class;
         case "char":
            return char.class;
         case "void":
            return void.class;
         default:
            return Class.forName(cn);
      }
   }
}
